#!/usr/bin/env python3

#
# This file is part of the plotting scripts supporting the CacheSC library
# (https://github.com/Miro-H/CacheSC), which implements Prime+Probe attacks on
# virtually and physically indexed caches.
#
# Copyright (C) 2020  Miro Haller
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Contact: miro.haller@alumni.ethz.ch
#
# Short description of this file:
# Plots cache side-channel timing observations from a log file that has a certain
# structure (see parser.py or the log file generated by the demo code)
#

import argparse
import ast
import os

# The next two lines select the Agg backend so matplotlib works without an
# X server (display); this must happen before importing matplotlib.pyplot
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import scipy.stats.mstats as stats

from logger import Logger
from parser import Parser


# Constants
# Fractions of the samples that are discarded before computing the trimmed
# statistics below (passed as scipy.stats.mstats `limits`, i.e. relative
# proportions trimmed from the low and high end respectively).
TRIM_HIGH_PERCENTAGE    = 0.05
TRIM_LOW_PERCENTAGE     = 0


# Parse command line arguments
# (named arg_parser to avoid shadowing the imported local `parser` module)
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("log_file", help="path to log file to parse")
arg_parser.add_argument("-o", "--output_folder",
                        help="path to folder for the produced plots",
                        default="./plots")
arg_parser.add_argument("--ylims", help="fix y axis of plot, tuple (y_min, y_max)",
                        default="()")
arg_parser.add_argument("-n", "--normalize", help="normalize samples using an additional "
                                                  "data set with uninfluenced data points",
                        action="store_true")
arg_parser.add_argument("-t", "--transpose", help="transpose data set, i.e. average over"
                                                  "the i-th entries of each sample",
                        action="store_true")
arg_parser.add_argument("-v", "--verbose", help="print debug output", action="store_true")

args = arg_parser.parse_args()

log_file_path = args.log_file
output_folder = args.output_folder
# Security fix: ast.literal_eval only accepts Python literals, so a malicious
# --ylims string cannot execute arbitrary code (unlike the previous eval()).
# The default therefore changed from "tuple()" (a call, not a literal) to "()".
y_lims        = ast.literal_eval(args.ylims)
do_normalize  = args.normalize
do_transpose  = args.transpose
verbose       = args.verbose

# Extract the bare file name (without directories); used to name the plot file
log_file_name = os.path.basename(log_file_path)

logger = Logger("plot")
if verbose:
    logger.set_verbose()

logger.title("Start plotting")

# Parse log file
# (named log_parser so it does not shadow the command-line argument parser)
log_parser = Parser()
samples, bl_samples, meta_data = log_parser.parse(log_file_path, do_normalize)

logger.line("Compute statistics")

# Prepare data
def tmean(arr):
    """Return the trimmed mean of arr, discarding the configured fractions
    of extreme values at the low and high end."""
    trim_limits = (TRIM_LOW_PERCENTAGE, TRIM_HIGH_PERCENTAGE)
    return stats.trimmed_mean(arr, limits=trim_limits)

def tstd(arr):
    """Return the trimmed standard deviation of arr, discarding the
    configured fractions of extreme values at the low and high end."""
    trim_limits = (TRIM_LOW_PERCENTAGE, TRIM_HIGH_PERCENTAGE)
    return stats.trimmed_std(arr, limits=trim_limits)



# Optionally transpose so that statistics are taken over the i-th entries of
# each sample instead of over each sample
if do_transpose:
    samples = list(map(list, zip(*samples)))

# Trimmed mean / std per entry (see TRIM_LOW/HIGH_PERCENTAGE)
avg_per_entry   = list(map(tmean, samples))
std_per_entry   = list(map(tstd, samples))


# Debug output for the first entry only
idx = 0
logger.debug(sorted(samples[idx]))
logger.debug(f"tmean: {tmean(samples[idx])}")

if do_normalize:
    logger.line("Normalize samples")
    if do_transpose:
        bl_samples = list(map(list, zip(*bl_samples)))
    avg_bl_per_entry = list(map(tmean, bl_samples))

    # Subtract the baseline (uninfluenced) average from each entry in place
    for i, baseline_avg in enumerate(avg_bl_per_entry):
        avg_per_entry[i] -= baseline_avg
    logger.debug(f"baseline: {avg_bl_per_entry[idx]}")
    logger.debug(f"normalized: {avg_per_entry[idx]}")

logger.line("Plot samples")

# Plot data
fig, ax = plt.subplots(figsize=(9,5), dpi=200)

if y_lims:
    ax.set_ylim(*y_lims)

x_vals = list(range(len(avg_per_entry)))
# Fix: pass the computed standard deviations as error bars. Previously
# std_per_entry was computed but never used, so the capthick/capsize
# error-bar cap settings had no visible effect.
ax.errorbar(x_vals, avg_per_entry, yerr=std_per_entry, label=meta_data.legend,
            fmt='-o', ms=5, capthick=.5, capsize=3)

# General settings
ax.set_title(f"Cache Side-Channel ({meta_data.samples_cnt} samples)")
ax.set_xlabel(meta_data.x_axis_label)
ax.set_ylabel(meta_data.y_axis_label)

if meta_data.legend:
    ax.legend(loc="upper right")  # same as loc=1, but self-documenting

footnote = f"Trimming data to ({TRIM_LOW_PERCENTAGE}, {1 - TRIM_HIGH_PERCENTAGE})"
plt.text(0.75, 0.01, footnote, transform=plt.gcf().transFigure)

# Make sure the output folder exists (default ./plots may not),
# otherwise savefig raises FileNotFoundError
os.makedirs(output_folder, exist_ok=True)

plot_name = log_file_name.rsplit(".", 1)[0] + "_plot.png"
plot_path = f"{output_folder}/{plot_name}"
logger.line(f"Save plot to {plot_path}")
plt.savefig(plot_path)

logger.line("Done")
